import argparse
import numpy
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch import functional as F
from torch.optim.lr_scheduler import LambdaLR
from transformers import GPT2LMHeadModel, GPT2Config


class NormalizedMutualInformationLoss(nn.Module):
    """Kernel-based dependence score between two batches of features.

    For each input a kernel matrix over the batch is built, right-multiplied
    by the centering matrix ``H = I - 11^T / m`` and ridge-normalised as
    ``R = Kc (Kc + eps * m * I)^-1``.  The returned score is
    ``sum(Rx * Ry^T)``, i.e. ``trace(Rx @ Ry)``.

    The score is returned as-is (not negated); whether it is minimised or
    maximised is up to the caller.

    Args:
        ktype: kernel name, one of ``'gaussian'``, ``'linear'`` or ``'IMQ'``.
        sigma: bandwidth used by the gaussian kernel.
    """

    def __init__(self, ktype='gaussian', sigma=50.0):
        super(NormalizedMutualInformationLoss, self).__init__()
        self.ktype = ktype
        self.sigma = sigma

    def distmat(self, X):
        """Return the matrix of pairwise squared Euclidean distances.

        Args:
            X: tensor whose rows are samples; a 1-D tensor is treated as a
                column of scalar samples.

        Returns:
            ``[m, m]`` tensor with ``D[i, j] = |X[i] - X[j]|^2``.
        """
        if len(X.shape) == 1:
            X = X.view(-1, 1)
        r = torch.sum(X * X, 1)
        r = r.view([-1, 1])
        a = torch.mm(X, torch.transpose(X, 0, 1))
        # |xi|^2 - 2 xi.xj + |xj|^2
        D = r.expand_as(a) - 2 * a + torch.transpose(r, 0, 1).expand_as(a)
        # Rounding can leave tiny negative entries; distances must be >= 0.
        D = torch.abs(D)
        return D

    def kernelmat(self, X, sigma, ktype='gaussian'):
        """Return the kernel matrix of ``X`` multiplied by the centering matrix.

        The kernel is always cast to float32 and kept on the device of ``X``;
        the centering matrix is created with the same dtype and device so the
        product works for any floating input and for accelerator tensors.

        Args:
            X: tensor whose rows are samples (1-D input becomes a column).
            sigma: bandwidth of the gaussian kernel (ignored by the others).
            ktype: ``'gaussian'``, ``'linear'`` or ``'IMQ'``.

        Returns:
            ``[m, m]`` float32 tensor ``Kx @ H``.

        Raises:
            ValueError: if ``ktype`` is not a supported kernel name.
        """
        if len(X.shape) == 1:
            X = X.view(-1, 1)

        m = int(X.size()[0])

        if ktype == "gaussian":
            Dxx = self.distmat(X)
            # Bandwidth is scaled by the feature dimension.
            variance = 2. * sigma * sigma * X.size()[1]
            Kx = torch.exp(-Dxx / variance)
        elif ktype == "linear":
            Kx = torch.mm(X, X.T)
        elif ktype == 'IMQ':
            # Inverse multiquadric kernel: 1 / sqrt(d^2 + 1).
            Dxx = self.distmat(X)
            Kx = torch.rsqrt(Dxx + 1)
        else:
            raise ValueError(
                "Unsupported kernel type %r; expected 'gaussian', 'linear' or 'IMQ'" % (ktype,)
            )

        # Cast without leaving the input's device (a plain FloatTensor cast
        # would move CUDA tensors back to the CPU).
        Kx = Kx.to(dtype=torch.float32)
        H = torch.eye(m, dtype=Kx.dtype, device=Kx.device) \
            - (1. / m) * torch.ones([m, m], dtype=Kx.dtype, device=Kx.device)

        Kxc = torch.mm(Kx, H)

        return Kxc

    def hsic_normalized_cca(self, x, y, sigma=50., ktype='gaussian'):
        """Compute the ridge-normalised kernel dependence score of ``x`` and ``y``.

        Args:
            x: ``[m, dx]`` tensor (1-D input becomes a column).
            y: ``[m, dy]`` tensor (1-D input becomes a column).
            sigma: gaussian kernel bandwidth.
            ktype: kernel name, see :meth:`kernelmat`.

        Returns:
            0-dim float32 tensor ``sum(Rx * Ry^T)``.
        """
        if len(x.shape) == 1:
            x = x.reshape(-1, 1)
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)

        m = int(x.size()[0])
        Kxc = self.kernelmat(x, sigma=sigma, ktype=ktype)
        Kyc = self.kernelmat(y, sigma=sigma, ktype=ktype)

        # Ridge term keeps the (rank-deficient) centered kernels invertible.
        epsilon = 1E-5
        K_I = torch.eye(m, dtype=Kxc.dtype, device=Kxc.device)
        Kxc_i = torch.inverse(Kxc + epsilon * m * K_I)
        Kyc_i = torch.inverse(Kyc + epsilon * m * K_I)
        Rx = (Kxc.mm(Kxc_i))
        Ry = (Kyc.mm(Kyc_i))
        # Element-wise product with the transpose summed == trace(Rx @ Ry).
        Pxy = torch.sum(torch.mul(Rx, Ry.t()))
        return Pxy

    def estimate_mi_hsic(self, x, y, ktype='gaussian', sigma=50.):
        """Thin wrapper returning :meth:`hsic_normalized_cca` for ``x`` and ``y``."""
        estimate_IXY = self.hsic_normalized_cca(x, y, ktype=ktype, sigma=sigma)
        return estimate_IXY

    def forward(self, x, y):
        """
        x: [batch_size, num_features_x]
        y: [batch_size, num_features_y]

        Returns the dependence score computed with the kernel type and
        bandwidth configured on this module.
        """
        if len(x.shape) == 1:
            x = x.reshape(-1, 1)
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)

        mi = self.estimate_mi_hsic(x, y, ktype=self.ktype, sigma=self.sigma)
        return mi
        # if mi <=0:
        #     return -torch.log( torch.exp(mi) / 100 )
        # if mi >=1:
        # return -mi

    
class MLP(nn.Module):
    """Three-layer perceptron with ReLU activations.

    Two hidden linear layers keep the input width; the last linear layer
    projects to ``num_classes`` outputs (no activation on the output).
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        hidden_size = input_size
        # Creation order is kept so parameter registration stays the same.
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_size``) to ``num_classes`` scores."""
        hidden = self.relu(self.layer1(x))
        hidden = self.relu(self.layer2(hidden))
        return self.layer3(hidden)


class BitSubsetParity(pl.LightningModule):
    def __init__(self, step_by_step: bool, num_of_bits: int, width=512, num_heads=8, depth=3, learning_rate=1e-3, warmup_steps=1000, weight_decay=1e-2, compress_steps=1, compress_type="tokenizer", evaluate_with_greedy_decoding=False):
        """Build the GPT-2 model and the auxiliary heads for the chosen mode.

        Args:
            step_by_step: when true, ``forward`` and the evaluation helpers
                extend the input with generated tokens before scoring it.
            num_of_bits: number of input bit positions; used for slicing and
                for the sequence-length computation below.
            width: GPT-2 embedding size (``n_embd``).
            num_heads: GPT-2 attention heads (``n_head``).
            depth: GPT-2 layer count (``n_layer``).
            learning_rate: read from ``self.hparams`` by ``configure_optimizers``.
            warmup_steps: read from ``self.hparams`` by ``configure_optimizers``.
            weight_decay: read from ``self.hparams`` by ``configure_optimizers``.
            compress_steps: number of hidden-state "think" tokens appended by
                ``_generate_compressed_hidden``.
            compress_type: selects the training/validation code path and which
                auxiliary modules are created here.
            evaluate_with_greedy_decoding: when true, generation does not sample.
        """
        super().__init__()
        # Must stay in __init__: records the constructor arguments in self.hparams.
        self.save_hyperparameters()
        self.step_by_step = step_by_step
        self.num_of_bits = num_of_bits
        self.compress_steps = compress_steps
        self.compress_type = compress_type
        self.remain_steps = int(numpy.log2(self.num_of_bits // 2)) - self.compress_steps

        self.evaluate_with_greedy_decoding = evaluate_with_greedy_decoding
        # Sequence length used for n_positions and as the generation target.
        # NOTE(review): the individual terms presumably mirror the dataset's
        # sequence layout -- confirm against the data module.
        self.generation_length = (
            self.num_of_bits
            + self.compress_steps
            + 2 * (self.compress_steps > 0)
            + 2 ** self.remain_steps
            - 1 - 1
        )
        if self.compress_type == "self_distill_mod_1":
            self.generation_length = ((self.num_of_bits * 3) // 2 - 2) if self.step_by_step else self.num_of_bits

        gpt2_config = GPT2Config(
            vocab_size=7,
            n_positions=self.generation_length,
            n_embd=width,
            n_layer=depth,
            n_head=num_heads,
            resid_pdrop=0,
            embd_pdrop=0,
            attn_pdrop=0,
            bos_token_id=2,
            eos_token_id=2,
        )
        self.model = GPT2LMHeadModel(gpt2_config)

        if self.compress_type == "self_distill_mod_1":
            self.thought_decoder = nn.Linear(width, width)
            self.out2in_decoder = nn.Linear(width, width)
            self.NMI_loss = NormalizedMutualInformationLoss(ktype='gaussian', sigma=50.0)

        if self.compress_type == "self_distill_mod_2":
            # Powers of two used to pack a chunk of step values into one class
            # index. The self-distillation loss slices up to num_of_bits // 4
            # weights, so generate that many (never fewer than the original
            # five) instead of a fixed-length list. Kept as a plain attribute
            # (not a buffer) so checkpoints are unaffected; it is moved to the
            # right device where it is used.
            num_weights = max(5, self.num_of_bits // 4)
            self.binary_sequence = 2 ** torch.arange(num_weights, dtype=torch.long)
            num_of_compress_classes = 2 ** (self.num_of_bits // 4)
            self.thought_decoder = MLP(width, num_of_compress_classes)
            self.out2in_decoder = nn.Linear(width, width)

        self.loss = nn.CrossEntropyLoss()
        self.decoder_loss = nn.MSELoss()
        # self.validation_step_outputs = []
    
    def forward(self, inputs):
        """Return one predicted bit (0 or 1) per example.

        The inputs are cast to long. In step-by-step mode they are first
        extended with ``self.model.generate`` to exactly
        ``self.generation_length`` tokens (sampling unless greedy decoding is
        enabled). The prediction is the argmax over the first two vocabulary
        entries at sequence position ``num_of_bits - 1``.
        """
        token_ids = inputs.long()
        if self.step_by_step:
            token_ids = self.model.generate(
                token_ids,
                do_sample=not self.evaluate_with_greedy_decoding,
                max_length=self.generation_length,
                min_length=self.generation_length,
                pad_token_id=2,
                num_beams=1,
            )
        binary_logits = self.model(token_ids).logits[:, self.num_of_bits - 1:, :2]
        # NOTE(review): index 0 is the first position after the input bits,
        # while the *_evaluation_common helpers read the final label from the
        # last position -- confirm this is intended in step-by-step mode.
        return torch.argmax(binary_logits, dim=2)[:, 0]

    def _embeddings_training_evaluation_common(self, input_embeddings, batch):
        """Score pre-computed input embeddings against ``batch['label']``.

        ``batch['label']`` is replaced in place by its long-typed version.
        Only the last ``label.shape[1]`` positions and the first two
        vocabulary logits take part in the loss.

        Returns:
            ``(loss, final_label_accuracy, accuracy_with_steps)`` where the
            final-label accuracy looks at the last label position only and the
            step accuracy averages over every label position.
        """
        labels = batch['label'].long()
        batch['label'] = labels
        label_span = labels.shape[1]

        model_output = self.model(inputs_embeds=input_embeddings)
        logits = model_output.logits[:, -label_span:, :2]

        # CrossEntropyLoss expects the class dimension second.
        loss = self.loss(logits.permute(0, 2, 1), labels)

        predictions = logits.argmax(dim=2)
        is_correct = (predictions == labels).float()
        accuracy_with_steps = is_correct.mean()
        final_label_accuracy = is_correct[:, -1].mean()
        return loss, final_label_accuracy, accuracy_with_steps
    
    def _self_distill_mod_1_training_evaluation_common(self, input_embeddings, think_tokens, batch):
        """Label loss plus two auxiliary MSE terms for ``self_distill_mod_1``.

        ``batch['label']`` and ``batch['all_steps']`` are replaced in place by
        their long-typed versions.

        Loss terms:
            * cross-entropy on the last ``label.shape[1]`` positions (first
              two vocabulary logits only);
            * per compress step, MSE between ``thought_decoder`` applied to
              the transformer's last hidden state on a prefix of
              ``all_steps`` and the corresponding think token;
            * MSE between ``out2in_decoder`` applied to the seven input
              embeddings and the matching rows of ``lm_head.weight``.

        Returns:
            ``(loss, final_label_accuracy, accuracy_with_steps)``.
        """
        labels = batch['label'].long()
        batch['label'] = labels
        all_steps = batch['all_steps'].long()
        batch['all_steps'] = all_steps

        label_span = labels.shape[1]
        logits = self.model(inputs_embeds=input_embeddings).logits[:, -label_span:, :2]
        loss = self.loss(logits.permute(0, 2, 1), labels)

        predictions = logits.argmax(dim=2)
        is_correct = (predictions == labels).float()
        accuracy_with_steps = is_correct.mean()
        final_label_accuracy = is_correct[:, -1].mean()

        # Self-distillation: the prefix grows by a chunk that halves each step.
        chunk_length = self.num_of_bits // 2
        prefix_length = self.num_of_bits
        for step in range(self.compress_steps):
            chunk_length //= 2
            prefix_length += chunk_length
            reference_outputs = self.model.transformer(all_steps[:, :prefix_length])
            decoded_reference = self.thought_decoder(reference_outputs.last_hidden_state[:, -1, :])
            think_state = think_tokens[step][:, -1, :]
            loss = loss + self.decoder_loss(decoded_reference, think_state)

        # Out-to-in transformer
        vocab_ids = torch.arange(7).to(input_embeddings.device)
        projected_inputs = self.out2in_decoder(self.model.transformer.wte(vocab_ids))
        output_rows = self.model.lm_head.weight[vocab_ids]
        loss = loss + self.decoder_loss(projected_inputs, output_rows)

        return loss, final_label_accuracy, accuracy_with_steps
    
    def _self_distill_mod_2_training_evaluation_common(self, input_embeddings, think_tokens, batch):
        """Label loss plus think-token classification for ``self_distill_mod_2``.

        ``batch['label']`` and ``batch['all_steps']`` are replaced in place by
        their long-typed versions, and ``self.binary_sequence`` is moved to
        the device of ``input_embeddings``.

        For every compress step a chunk of ``all_steps`` is packed into one
        class index (weighted sum with ``binary_sequence``) and
        ``thought_decoder`` must predict that index from the step's think
        token; the cross-entropy of that prediction is added to the loss.

        Returns:
            ``(loss, final_label_accuracy, mean_accuracy)`` where
            ``mean_accuracy`` averages the per-position label accuracy with
            the classification accuracy of every compress step.
        """
        labels = batch['label'].long()
        batch['label'] = labels
        all_steps = batch['all_steps'].long()
        batch['all_steps'] = all_steps
        bit_weights = self.binary_sequence.to(input_embeddings.device)
        self.binary_sequence = bit_weights

        label_span = labels.shape[1]
        logits = self.model(inputs_embeds=input_embeddings).logits[:, -label_span:, :2]
        loss = self.loss(logits.permute(0, 2, 1), labels)

        predictions = logits.argmax(dim=2)
        is_correct = (predictions == labels).float()
        accuracy_with_steps = is_correct.mean()
        final_label_accuracy = is_correct[:, -1].mean()

        # Self-distillation: chunks start right after the input bits and
        # halve in length at every step.
        chunk_length = self.num_of_bits // 2
        chunk_start = self.num_of_bits
        accuracies = [accuracy_with_steps]
        for step in range(self.compress_steps):
            chunk_length //= 2

            compress_pred = self.thought_decoder(think_tokens[step][:, -1, :])
            chunk = all_steps[:, chunk_start: chunk_start + chunk_length]
            compress_label = (chunk * bit_weights[:chunk_length]).sum(dim=1)
            loss = loss + self.loss(compress_pred, compress_label)

            chunk_hits = (compress_pred.argmax(dim=1) == compress_label).float()
            accuracies.append(chunk_hits.mean())

            chunk_start += chunk_length

        # Out-to-in transformer
        # input_range = torch.arange(7).to(input_embeddings.device)
        # o2i_loss = self.decoder_loss(
        #     self.out2in_decoder(self.model.transformer.wte(input_range)),
        #     self.model.lm_head.weight[input_range]
        # )
        # loss += o2i_loss

        return loss, final_label_accuracy, sum(accuracies) / len(accuracies)

    def _training_evaluation_common(self, batch):
        """Score ``batch['input_ids']`` against ``batch['label']``.

        ``batch['label']`` is replaced in place by its long-typed version.
        Only the last ``label.shape[1]`` positions and the first two
        vocabulary logits take part in the loss.

        Returns:
            ``(loss, final_label_accuracy, accuracy_with_steps)``.
        """
        labels = batch['label'].long()
        batch['label'] = labels
        label_span = labels.shape[1]

        token_ids = batch['input_ids'].long()
        logits = self.model(token_ids).logits[:, -label_span:, :2]

        # CrossEntropyLoss expects the class dimension second.
        loss = self.loss(logits.permute(0, 2, 1), labels)

        predictions = logits.argmax(dim=2)
        is_correct = (predictions == labels).float()
        accuracy_with_steps = is_correct.mean()
        final_label_accuracy = is_correct[:, -1].mean()
        return loss, final_label_accuracy, accuracy_with_steps

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        The loss routine is chosen from ``compress_type`` (only when
        ``compress_steps > 0``); every other configuration uses the plain
        token-id path. Step accuracy is logged only in step-by-step mode.
        """
        compress_type = self.compress_type
        if compress_type == "autoregression" and self.compress_steps > 0:
            compressed = self._generate_compressed_hidden(batch)
            metrics = self._embeddings_training_evaluation_common(
                compressed['input_embeddings'], batch)
        elif "self_distill_mod_1" in compress_type and self.compress_steps > 0:
            compressed = self._generate_compressed_hidden(batch)
            metrics = self._self_distill_mod_1_training_evaluation_common(
                compressed['input_embeddings'], compressed['think_tokens'], batch)
        elif "self_distill_mod_2" in compress_type and self.compress_steps > 0:
            compressed = self._generate_compressed_hidden(batch)
            metrics = self._self_distill_mod_2_training_evaluation_common(
                compressed['input_embeddings'], compressed['think_tokens'], batch)
        else:
            metrics = self._training_evaluation_common(batch)

        loss, final_label_accuracy, accuracy_with_steps = metrics
        self.log("loss/train", loss)
        self.log("accuracy/train", final_label_accuracy)
        if self.step_by_step:
            self.log("accuracy_with_steps/train", accuracy_with_steps)
        return loss
    
    def _generate_compressed_hidden(self, batch):
        """Build input embeddings with ``compress_steps`` think tokens spliced in.

        The first ``num_of_bits + 1`` token ids are embedded. Then, once per
        compress step, the transformer's last hidden state at the final
        position is appended to the sequence as the next input embedding
        (fed back as-is, without any projection). The token ids that follow
        the extended prefix are embedded and appended at the end.

        Returns:
            dict with ``'input_embeddings'`` (the full embedded sequence),
            ``'think_tokens'`` (list of the appended hidden states, one per
            compress step) and ``'batch'`` (the batch that was passed in).
        """
        token_ids = batch['input_ids'].long()
        embed_tokens = self.model.transformer.wte

        prefix_ids = token_ids[:, :self.num_of_bits + 1].clone()
        sequence_embeds = embed_tokens(prefix_ids)

        think_tokens = []
        for _ in range(self.compress_steps):
            transformer_out = self.model.transformer(
                inputs_embeds=sequence_embeds, output_hidden_states=True)
            think_token = transformer_out.last_hidden_state[:, -1:, :]
            think_tokens.append(think_token)
            sequence_embeds = torch.cat((sequence_embeds, think_token), dim=1)

        # Everything after the prefix and the think tokens is embedded normally.
        remaining_ids = token_ids[:, sequence_embeds.shape[1]:]
        input_embeddings = torch.cat((sequence_embeds, embed_tokens(remaining_ids)), dim=1)
        return {
            'input_embeddings': input_embeddings,
            "think_tokens": think_tokens,
            'batch': batch,
        }
    
    def _generate_steps(self, input_embeddings):
        """Greedily extend an embedded sequence by ``2**remain_steps - 2`` tokens.

        Nothing is appended unless step-by-step mode is on and
        ``remain_steps > 1``. Each new token is the argmax of ``lm_head`` at
        the last position (over all vocabulary logits); its embedding is
        appended before the next iteration.
        """
        if not (self.step_by_step and self.remain_steps > 1):
            return input_embeddings

        sequence_embeds = input_embeddings
        for _ in range(2 ** self.remain_steps - 2):
            transformer_out = self.model.transformer(
                inputs_embeds=sequence_embeds, output_hidden_states=True)
            last_hidden = transformer_out.last_hidden_state[:, -1:, :]
            next_token_ids = self.model.lm_head(last_hidden).argmax(dim=2)
            next_embeds = self.model.transformer.wte(next_token_ids)
            sequence_embeds = torch.cat((sequence_embeds, next_embeds), dim=1)
        return sequence_embeds

    def _prepare_batch_for_evaluation(self, batch):
        """Replace ``batch['input_ids']`` with a generated sequence when needed.

        Only in step-by-step mode with ``remain_steps > 1``: the ids are
        extended by ``self.model.generate`` to exactly
        ``self.generation_length`` tokens (sampling unless greedy decoding is
        enabled) and detached. The batch is modified in place and returned.
        """
        needs_generation = self.step_by_step and self.remain_steps > 1
        if not needs_generation:
            return batch

        generated_ids = self.model.generate(
            batch['input_ids'].long(),
            do_sample=not self.evaluate_with_greedy_decoding,
            max_length=self.generation_length,
            min_length=self.generation_length,
            pad_token_id=2,
        )
        batch['input_ids'] = generated_ids.detach()
        return batch

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracies for one batch.

        Compressed modes (with ``compress_steps > 0``) build think-token
        embeddings and greedily extend them; other configurations go through
        the token-id path with an optionally generated batch. The loss is
        logged under both ``"val_loss"`` and ``"loss/val"``.
        """
        compressed_modes = ("autoregression", "self_distill_mod_1", "self_distill_mod_2")
        if self.compress_type in compressed_modes and self.compress_steps > 0:
            # Generate compressed hidden states for autoregressive compression
            compressed = self._generate_compressed_hidden(batch)
            extended_embeddings = self._generate_steps(compressed['input_embeddings'])
            if self.compress_type == "self_distill_mod_2":
                # NOTE(review): this branch scores the embeddings *before*
                # _generate_steps, so extended_embeddings is unused here --
                # confirm that is intended.
                metrics = self._self_distill_mod_2_training_evaluation_common(
                    compressed['input_embeddings'], compressed['think_tokens'], batch)
            else:
                metrics = self._embeddings_training_evaluation_common(
                    extended_embeddings, compressed['batch'])
        else:
            prepared = self._prepare_batch_for_evaluation(batch)
            metrics = self._training_evaluation_common(prepared)

        loss, accuracy, accuracy_with_steps = metrics
        self.log("val_loss", loss)
        self.log("loss/val", loss)
        self.log("accuracy/val", accuracy)
        self.log("accuracy_with_steps/val", accuracy_with_steps)
        # self.validation_step_outputs.append({
        #     "loss": loss.detach().item(),
        #     "accuracy": accuracy.detach().item(),
        #     "accuracy_with_steps": accuracy_with_steps.detach().item(),
        # })
    
    # def on_validation_epoch_end(self):
    #     avg_loss = numpy.mean([x['loss'] for x in self.validation_step_outputs])
    #     avg_accuracy = numpy.mean([x['accuracy'] for x in self.validation_step_outputs])
    #     avg_accuracy_with_steps = numpy.mean([x['accuracy_with_steps'] for x in self.validation_step_outputs])
    #     self.log("val_mean_loss", avg_loss)
    #     self.log("accuracy/val_mean", avg_accuracy)
    #     if self.step_by_step:
    #         self.log("accuracy_with_steps/val_mean", avg_accuracy_with_steps)
    #     self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        """Log test loss and final-label accuracy for one batch.

        Always uses the token-id path (after optional generation); the step
        accuracy returned by the helper is discarded.
        """
        prepared = self._prepare_batch_for_evaluation(batch)
        loss, accuracy, _ = self._training_evaluation_common(prepared)
        self.log("loss/test", loss)
        self.log("accuracy/test", accuracy)

    def configure_optimizers(self):
        """Create the Adam optimizer and a per-step linear warm-up schedule.

        The learning-rate factor rises linearly from ``start_factor`` (0.01)
        to 1.0 over ``hparams.warmup_steps`` scheduler steps and stays at 1.0
        afterwards. With ``warmup_steps <= 0`` there is no warm-up and the
        factor is 1.0 from the start (a value of 0 previously raised
        ``ZeroDivisionError`` when the scheduler was constructed).

        Returns:
            dict with the ``"optimizer"`` and an ``"lr_scheduler"`` config
            that steps the scheduler every optimizer step.
        """
        # NOTE(review): only self.model's parameters are optimized. Modules
        # attached directly to this LightningModule (thought_decoder,
        # out2in_decoder) are not passed to the optimizer -- confirm whether
        # they are meant to stay at their initial values.
        parameters = self.model.parameters()
        optimizer = torch.optim.Adam(parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
        start_factor = 1e-2

        def lr_lambda(step):
            # The scheduler interval is "step", so the argument counts steps.
            warmup_steps = self.hparams.warmup_steps
            if warmup_steps <= 0:
                return 1.0
            progress = min(warmup_steps, step) / warmup_steps
            return start_factor + (1. - start_factor) * progress

        lr_scheduler = LambdaLR(optimizer, lr_lambda)
        lr_scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val_loss",
            "strict": True,
            "name": None,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config
        }

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options.

        Adds the numeric hyper-parameters plus two mutually overriding
        switches that set ``evaluate_with_greedy_decoding`` (default False).
        """
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

        numeric_options = (
            ('--width', int, 512),
            ('--num_heads', int, 8),
            ('--depth', int, 3),
            ('--learning_rate', float, 0.001),
            ('--warmup_steps', int, 1000),
            ('--weight_decay', float, 1e-2),
        )
        for flag, value_type, default in numeric_options:
            parser.add_argument(flag, type=value_type, default=default)

        # Both switches write to the same destination.
        decoding_switches = (
            ('--evaluate_with_greedy_decoding', 'store_true'),
            ('--evaluate_with_sampling', 'store_false'),
        )
        for flag, action in decoding_switches:
            parser.add_argument(flag, dest='evaluate_with_greedy_decoding', action=action)
        parser.set_defaults(evaluate_with_greedy_decoding=False)
        return parser
